{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**This notebook shows how to run inference in the training-time two-view setting on the validation or training set of MegaDepth, and how to visualize the training metrics and losses.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "from pixloc import run_Aachen\n",
    "from pixloc.pixlib.datasets.megadepth import MegaDepth\n",
    "from pixloc.pixlib.utils.tensor import batch_to_device, map_tensor\n",
    "from pixloc.pixlib.utils.tools import set_seed\n",
    "from pixloc.pixlib.utils.experiments import load_experiment\n",
    "from pixloc.visualization.viz_2d import (\n",
    "    plot_images, plot_keypoints, plot_matches, cm_RdGn,\n",
    "    features_to_RGB, add_text)\n",
    "\n",
    "# Inference only: turn off autograd globally for the whole notebook.\n",
    "torch.set_grad_enabled(False)\n",
    "# Use bilinear interpolation by default whenever images are displayed.\n",
    "mpl.rc('image', interpolation='bilinear')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create a validation or training dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[09/24/2021 16:58:00 pixloc.pixlib.datasets.base_dataset INFO] Creating dataset MegaDepth\n",
      "[09/24/2021 16:58:00 pixloc.pixlib.datasets.megadepth INFO] Sampling new images or pairs with seed 1\n",
      " 44%|██████████████████████▉                             | 34/77 [00:03<00:04,  9.24it/s][09/24/2021 16:58:04 pixloc.pixlib.datasets.megadepth WARNING] Scene 0209 does not have an info file\n",
      "100%|████████████████████████████████████████████████████| 77/77 [00:06<00:00, 11.66it/s]\n"
     ]
    }
   ],
   "source": [
    "# Dataset options, grouped by purpose. The exact meaning of each key is defined\n",
    "# by the MegaDepth dataset class -- check its default configuration when tuning.\n",
    "conf = dict(\n",
    "    # Pair selection and 3D point sampling\n",
    "    min_overlap=0.4,\n",
    "    max_overlap=1.0,\n",
    "    max_num_points3D=512,\n",
    "    force_num_points3D=True,\n",
    "    # Image preprocessing\n",
    "    resize=512,\n",
    "    resize_by='min',\n",
    "    crop=512,\n",
    "    optimal_crop=True,\n",
    "    # Initial pose. Disabled alternative:\n",
    "    #     'init_pose': 'max_error',\n",
    "    #     'init_pose_max_error': 4,\n",
    "    #     'init_pose_num_samples': 50,\n",
    "    init_pose=[0.75, 1.],\n",
    "    # Data loading\n",
    "    batch_size=1,\n",
    "    seed=1,\n",
    "    num_workers=0,\n",
    ")\n",
    "\n",
    "# Shuffled loader over the validation split.\n",
    "loader = MegaDepth(conf).get_data_loader('val', shuffle=True)\n",
    "# Keep a handle on the items of the underlying dataset.\n",
    "orig_items = loader.dataset.items"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the training experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the example experiment. Replace with your own training experiment.\n",
    "exp = run_Aachen.experiment\n",
    "# Fall back to the CPU so that the notebook still runs on machines without a GPU.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "# Overrides passed to load_experiment (presumably merged into the configuration\n",
    "# saved with the experiment -- confirm in load_experiment). A dedicated name is\n",
    "# used so that the dataset `conf` defined above is not overwritten.\n",
    "refiner_conf = {\n",
    "    'optimizer': {'num_iters': 20},\n",
    "}\n",
    "refiner = load_experiment(exp, refiner_conf).to(device)\n",
    "# Show the configuration that the loaded model actually uses.\n",
    "print(OmegaConf.to_yaml(refiner.conf))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run on a few examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Reference image: red/green = reprojections of 3D points that are not visible/visible in the query at the ground-truth pose\n",
    "- Query image: red/blue/green = reprojections of 3D points at the initial/final/GT poses\n",
    "- ΔP/ΔR/Δt are final errors in terms of 2D reprojections, rotation, and translation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(7)\n",
    "num_examples = 5  # number of image pairs to run and visualize\n",
    "show_features = False  # set to True to also plot the confidence and feature maps\n",
    "\n",
    "for _, batch in zip(range(num_examples), loader):\n",
    "    # Run the refiner on `device`; keep the first batch element on CPU for plotting.\n",
    "    data_ = batch_to_device(batch, device)\n",
    "    pred_ = refiner(data_)\n",
    "    pred = map_tensor(pred_, lambda x: x[0].cpu())\n",
    "    data = map_tensor(batch, lambda x: x[0].cpu())\n",
    "    cam_q = data['query']['camera']\n",
    "    p3D_r = data['ref']['points3D']\n",
    "\n",
    "    # Project the reference 3D points into the reference image and into the\n",
    "    # query image at the ground-truth, initial, and final (optimized) poses.\n",
    "    p2D_r, valid_r = data['ref']['camera'].world2image(p3D_r)\n",
    "    p2D_q_gt, valid_q = cam_q.world2image(data['T_r2q_gt'] * p3D_r)\n",
    "    p2D_q_init, _ = cam_q.world2image(data['T_r2q_init'] * p3D_r)\n",
    "    p2D_q_opt, _ = cam_q.world2image(pred['T_r2q_opt'][-1] * p3D_r)\n",
    "    valid = valid_q & valid_r\n",
    "\n",
    "    # Losses and metrics are computed on the batched, on-device tensors.\n",
    "    losses = refiner.loss(pred_, data_)\n",
    "    mets = refiner.metrics(pred_, data_)\n",
    "    # The separators are added when joining, so no string carries a dangling '; '.\n",
    "    errP = f\"ΔP {losses['reprojection_error/init'].item():.2f} -> {losses['reprojection_error'].item():.3f} px\"\n",
    "    errR = f\"ΔR {mets['R_error/init'].item():.2f} -> {mets['R_error'].item():.3f} deg\"\n",
    "    errt = f\"Δt {mets['t_error/init'].item():.2f} -> {mets['t_error'].item():.3f} %m\"\n",
    "    print('; '.join([errP, errR, errt]))\n",
    "\n",
    "    # Images are permuted from channels-first to channels-last for display.\n",
    "    imr = data['ref']['image'].permute(1, 2, 0)\n",
    "    imq = data['query']['image'].permute(1, 2, 0)\n",
    "    # Reference title: (scene, #points valid in the reference, #points valid in the query).\n",
    "    ref_title = (data['scene'][0], valid_r.sum().item(), valid_q.sum().item())\n",
    "    plot_images([imr, imq], titles=[ref_title, f'{errP}; {errR}'])\n",
    "    plot_keypoints([p2D_r[valid_r], p2D_q_gt[valid]], colors=[cm_RdGn(valid[valid_r]), 'lime'])\n",
    "    # Empty arrays leave the reference image untouched while drawing on the query.\n",
    "    plot_keypoints([np.empty((0, 2)), p2D_q_init[valid]], colors='red')\n",
    "    plot_keypoints([np.empty((0, 2)), p2D_q_opt[valid]], colors='blue')\n",
    "    add_text(0, 'reference')\n",
    "    add_text(1, 'query')\n",
    "\n",
    "    if not show_features:\n",
    "        continue\n",
    "    # Optional: confidence maps (with the image overlaid) and feature maps per level.\n",
    "    for i, (F0, F1) in enumerate(zip(pred['ref']['feature_maps'], pred['query']['feature_maps'])):\n",
    "        C_r, C_q = pred['ref']['confidences'][i][0], pred['query']['confidences'][i][0]\n",
    "        plot_images([C_r, C_q], cmaps=mpl.cm.turbo)\n",
    "        add_text(0, f'Level {i}')\n",
    "\n",
    "        # Overlay the images using the extent of the confidence maps so that they align.\n",
    "        axes = plt.gcf().axes\n",
    "        axes[0].imshow(imr, alpha=0.2, extent=axes[0].images[0].get_extent())\n",
    "        axes[1].imshow(imq, alpha=0.2, extent=axes[1].images[0].get_extent())\n",
    "        plot_images(features_to_RGB(F0.numpy(), F1.numpy(), skip=1))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
